{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# `NameSupply`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm\n",
    "import tvm.testing\n",
    "\n",
    "from tvm import relay\n",
    "from tvm.ir import GlobalVar, structural_equal\n",
    "from tvm.ir.supply import NameSupply\n",
    "from tvm.ir.supply import GlobalVarSupply\n",
    "\n",
    "\n",
    "def test_name_supply():\n",
    "    name_supply = NameSupply(\"prefix\")\n",
    "    name_supply.reserve_name(\"test\")\n",
    "\n",
    "    assert name_supply.contains_name(\"test\")\n",
    "    assert name_supply.fresh_name(\"test\") == \"prefix_test_1\"\n",
    "    assert name_supply.contains_name(\"test_1\")\n",
    "    assert not name_supply.contains_name(\"test_1\", False)\n",
    "    assert not name_supply.contains_name(\"test_2\")\n",
    "\n",
    "\n",
    "def test_global_var_supply_from_none():\n",
    "    var_supply = GlobalVarSupply()\n",
    "    global_var = GlobalVar(\"test\")\n",
    "    var_supply.reserve_global(global_var)\n",
    "\n",
    "    assert structural_equal(var_supply.unique_global_for(\"test\"), global_var)\n",
    "    assert not structural_equal(var_supply.fresh_global(\"test\"), global_var)\n",
    "\n",
    "\n",
    "def test_global_var_supply_from_name_supply():\n",
    "    name_supply = NameSupply(\"prefix\")\n",
    "    var_supply = GlobalVarSupply(name_supply)\n",
    "    global_var = GlobalVar(\"test\")\n",
    "    var_supply.reserve_global(global_var)\n",
    "\n",
    "    assert structural_equal(var_supply.unique_global_for(\"test\", False), global_var)\n",
    "    assert not structural_equal(var_supply.unique_global_for(\"test\"), global_var)\n",
    "\n",
    "\n",
    "def test_global_var_supply_from_ir_mod():\n",
    "    x = relay.var(\"x\")\n",
    "    y = relay.var(\"y\")\n",
    "    mod = tvm.IRModule()\n",
    "    global_var = GlobalVar(\"test\")\n",
    "    mod[global_var] = relay.Function([x, y], relay.add(x, y))\n",
    "    var_supply = GlobalVarSupply(mod)\n",
    "\n",
    "    second_global_var = var_supply.fresh_global(\"test\", False)\n",
    "\n",
    "    assert structural_equal(var_supply.unique_global_for(\"test\", False), global_var)\n",
    "    assert not structural_equal(var_supply.unique_global_for(\"test\"), global_var)\n",
    "    assert not structural_equal(second_global_var, global_var)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
